### Functions for estimation of structural crisis bargaining model

library("foreach")
library("Formula")
library("maxLik")
library("randtoolbox")
library("Rcpp")

sourceCpp("backend_main.cpp")

## Assemble the data object consumed by the likelihood and gradient code.
##
## f_dispute: one LHS part and four RHS parts,
##   id + war + win_a + win_b ~ [l_a] | [l_b] | [s_a] | [s_b]
## The RHS parts are the regressors for side A's and side B's location
## (l_a, l_b) and scale (s_a, s_b) parameters; l_a/l_b and s_a/s_b must
## each produce the same number of columns.
##
## f_participant: one LHS part and two RHS parts,
##   id + side_a ~ x1 + x2 + ... | z1 + z2 + ...
## Intercepts are dropped from both RHS parts.
##
## Checks performed: dispute IDs are unique; war, win_a, win_b and side_a are
## 0/1; no dispute is won by both sides; both datasets contain exactly the
## same set of IDs. xlev_dispute / xlev_participant are forwarded to
## model.frame() as factor levels, and the levels actually used are returned.
## n_halton is the number of Halton draws stored in p_theta_a for the
## simulated likelihood.
##
## NOTE(review): model.frame() is called without an explicit na.action, so the
## session default (na.omit unless changed) drops incomplete rows here rather
## than reporting them.
structwar_setup <- function(f_dispute,
                            f_participant,
                            data_dispute,
                            data_participant,
                            xlev_dispute = NULL,
                            xlev_participant = NULL,
                            n_halton = 1024) {
  ## Dispute level: outcome columns plus location/scale design matrices
  fd <- as.Formula(f_dispute)
  if (!identical(length(fd), c(1L, 4L)))
    stop("f_dispute must have a single LHS and four RHS")
  frame_d <- model.frame(fd, data = data_dispute, xlev = xlev_dispute)
  outcomes <- model.part(fd, data = frame_d, lhs = 1)
  design_d <- function(part) model.matrix(fd, data = frame_d, rhs = part)
  L_a <- design_d(1)
  L_b <- design_d(2)
  if (ncol(L_a) != ncol(L_b))
    stop("L_a and L_b must have same number of variables")
  S_a <- design_d(3)
  S_b <- design_d(4)
  if (ncol(S_a) != ncol(S_b))
    stop("S_a and S_b must have same number of variables")
  levels_d <- .getXlevels(attr(frame_d, "terms"), frame_d)

  if (ncol(outcomes) != 4)
    stop("f_dispute must be of the form 'id + war + win_a + win_b ~ ...'")
  dispute_level_id <- outcomes[, 1]
  war <- outcomes[, 2]
  win_a <- outcomes[, 3]
  win_b <- outcomes[, 4]
  if (anyDuplicated(dispute_level_id) > 0)
    stop("some dispute IDs are duplicated")
  binary_ok <- c(war, win_a, win_b) %in% 0:1
  if (!all(binary_ok))
    stop("war, win_a, and win_b must be binary")
  if (any(win_a + win_b > 1))
    stop("some disputes coded as won by both sides")

  ## Participant level: side indicator plus intercept-free X and Z
  fp <- as.Formula(f_participant)
  if (!identical(length(fp), c(1L, 2L)))
    stop("f_participant must have a single LHS and two RHS")
  fp <- update(fp, . ~ . - 1 | . - 1)
  frame_p <- model.frame(fp, data = data_participant, xlev = xlev_participant)
  part_lhs <- model.part(fp, data = frame_p, lhs = 1)
  X <- model.matrix(fp, data = frame_p, rhs = 1)
  Z <- model.matrix(fp, data = frame_p, rhs = 2)
  levels_p <- .getXlevels(attr(frame_p, "terms"), frame_p)

  if (ncol(part_lhs) != 2)
    stop("f_participant must be of the form 'id + side_a ~ ...'")
  state_level_id <- part_lhs[, 1]
  side_a <- part_lhs[, 2]
  if (!all(side_a %in% 0:1))
    stop("side_a must be binary")

  ## Every ID must appear at both levels
  orphan_states <- setdiff(state_level_id, dispute_level_id)
  if (length(orphan_states) > 0) {
    stop("IDs in state-level but not dispute-level data: ",
         paste(orphan_states, collapse = ", "))
  }
  orphan_disputes <- setdiff(dispute_level_id, state_level_id)
  if (length(orphan_disputes) > 0) {
    stop("IDs in dispute-level but not state-level data: ",
         paste(orphan_disputes, collapse = ", "))
  }

  list(dispute_level_id = dispute_level_id,
       war = war,
       win_a = win_a,
       win_b = win_b,
       state_level_id = state_level_id,
       side_a = side_a,
       X = X,
       Z = Z,
       L_a = L_a,
       L_b = L_b,
       S_a = S_a,
       S_b = S_b,
       p_theta_a = halton(n_halton),
       xlev_dispute = levels_d,
       xlev_participant = levels_p)
}

## Clamp every element of `x` into [.Machine$double.xmin, .Machine$double.xmax]:
## zero, negative and underflowing values become the smallest positive
## normalized double, and +Inf becomes the largest finite double. Missing
## values (NA/NaN) stay missing; attributes such as dim/names are kept.
constrain_rpp <- function(x) {
  lower <- .Machine$double.xmin
  upper <- .Machine$double.xmax
  pmin(pmax(x, lower), upper)
}

## Unpack the flat parameter vector `est` into the quantities used by the
## likelihood. `est` is read positionally as four consecutive blocks:
##   beta   (ncol(setup$X)), gamma (ncol(setup$Z)),
##   coef_l (ncol(setup$L_a)), coef_s (ncol(setup$S_a)).
## Returns a list with
##   ratio         exp(Z %*% gamma - X %*% beta), clamped by constrain_rpp()
##   loc_a, loc_b  L_a %*% coef_l and L_b %*% coef_l
##   scl_a, scl_b  exp(S_a %*% coef_s) and exp(S_b %*% coef_s), clamped
## and, when for_counterfactuals is TRUE, the unclamped
##   efx = exp(X %*% beta) and cost = exp(Z %*% gamma).
extract_params <- function(est, setup, for_counterfactuals = FALSE) {
  widths <- c(ncol(setup$X), ncol(setup$Z), ncol(setup$L_a), ncol(setup$S_a))
  offsets <- cumsum(widths) - widths
  take_block <- function(k) est[offsets[k] + seq_len(widths[k])]

  beta <- take_block(1)
  gamma <- take_block(2)
  coef_l <- take_block(3)
  coef_s <- take_block(4)

  out <- list(
    ## Individual cost-effectiveness ratios
    ratio = constrain_rpp(drop(exp(setup$Z %*% gamma - setup$X %*% beta))),
    ## Distributional parameters
    loc_a = drop(setup$L_a %*% coef_l),
    loc_b = drop(setup$L_b %*% coef_l),
    scl_a = constrain_rpp(drop(exp(setup$S_a %*% coef_s))),
    scl_b = constrain_rpp(drop(exp(setup$S_b %*% coef_s)))
  )

  if (!for_counterfactuals)
    return(out)

  c(out, list(efx = drop(exp(setup$X %*% beta)),
              cost = drop(exp(setup$Z %*% gamma))))
}

## Log-likelihood contributions at parameter vector `est`.
##
## Unpacks `est` with extract_params() and passes it, together with the data
## held in `setup` (see structwar_setup()), to the compiled loglik_backend().
## Returns the vector produced by loglik_backend(), with NaN entries replaced
## by -Inf so those parameter values are treated as infeasible rather than
## propagating NaN into the optimizer.
loglik_structwar_ac <- function(est, setup) {
  parms <- extract_params(est, setup)
  ll <- loglik_backend(dispute_level_id = setup$dispute_level_id,
                       war = setup$war,
                       win_a = setup$win_a,
                       win_b = setup$win_b,
                       state_level_id = setup$state_level_id,
                       side_a = setup$side_a,
                       p_theta_a = setup$p_theta_a,
                       ratio = parms$ratio,
                       loc_a = parms$loc_a,
                       loc_b = parms$loc_b,
                       scl_a = parms$scl_a,
                       scl_b = parms$scl_b,
                       xmax = .Machine$double.xmax)

  ## Treat NaNs as -Inf. Subassignment (unlike ifelse()) keeps `ll` numeric
  ## even when it is empty and leaves its other attributes untouched.
  ll[is.nan(ll)] <- -Inf

  ll
}

## Gradient of the log-likelihood at `est`, used as the `grad` argument to
## maxLik() in bfgs_structwar_ac(). Parameters are unpacked with
## extract_params(); the compiled grad_backend() also receives the design
## matrices from `setup`. Its result is returned unchanged.
grad_structwar_ac <- function(est, setup) {
  p <- extract_params(est, setup)
  grad_backend(
    ## Data
    dispute_level_id = setup$dispute_level_id,
    war = setup$war,
    win_a = setup$win_a,
    win_b = setup$win_b,
    state_level_id = setup$state_level_id,
    side_a = setup$side_a,
    p_theta_a = setup$p_theta_a,
    ## Current parameter values
    ratio = p$ratio,
    loc_a = p$loc_a,
    loc_b = p$loc_b,
    scl_a = p$scl_a,
    scl_b = p$scl_b,
    ## Design matrices for the chain rule
    X = setup$X,
    Z = setup$Z,
    L_a = setup$L_a,
    L_b = setup$L_b,
    S_a = setup$S_a,
    S_b = setup$S_b,
    xmax = .Machine$double.xmax
  )
}

## Starting values for the optimizer, as a named numeric vector in the order
## read by extract_params(): beta (X), gamma (Z), location (L_a), scale (S_a).
## Names are "<block>:<column name>", e.g. "beta:x1" or "loc:(Intercept)".
## All entries default to 0. Entries of `init` whose names match are copied
## in; names in `init` that match no parameter are ignored.
##
## Every block is guarded on having at least one column. paste() with a
## zero-length argument still returns one string (e.g. "loc:"), which would
## add a phantom parameter and shift later blocks out of alignment with
## extract_params().
start_vals <- function(setup, init = NULL) {
  block_names <- function(prefix, M) {
    if (ncol(M) == 0)
      return(NULL)
    cn <- colnames(M)
    ## Fall back on column positions so each column still gets one name
    if (is.null(cn))
      cn <- seq_len(ncol(M))
    paste(prefix, cn, sep = ":")
  }

  cf_names <- c(
    block_names("beta", setup$X),
    block_names("gamma", setup$Z),
    block_names("loc", setup$L_a),
    block_names("scl", setup$S_a)
  )
  cf <- rep(0.0, length(cf_names))
  names(cf) <- cf_names

  if (!is.null(init)) {
    overlap <- intersect(names(cf), names(init))
    cf[overlap] <- init[overlap]
  }

  cf
}

## Parameter scaling for optim's `parscale`, in the same order as
## start_vals(): 1 / sd(column) for each column of X, of Z, of the stacked
## L_a/L_b, and of the stacked S_a/S_b. Columns whose sd is below `tol`
## (e.g. an intercept) or is not finite (e.g. NA with only one row) get a
## scale of 1 instead, so the result is always finite and positive.
make_parscale <- function(setup, tol = 1e-8) {
  col_sd <- function(M) apply(M, 2, sd)
  ans <- c(
    col_sd(setup$X),
    col_sd(setup$Z),
    col_sd(rbind(setup$L_a, setup$L_b)),
    col_sd(rbind(setup$S_a, setup$S_b))
  )
  ## Without the is.finite() test, `abs(NA) < tol` is NA, the subassignment
  ## skips that element, and an NA reaches optim's parscale.
  ans[!is.finite(ans) | abs(ans) < tol] <- 1
  1 / ans
}

## Maximize the simulated log-likelihood with maxLik's BFGS method.
##
## Starting values come from start_vals(setup, init). With `scale = TRUE`,
## make_parscale(setup) supplies the parameter scaling; otherwise every
## parameter is scaled by 1. With `useGrad = FALSE`, no gradient function is
## passed and maxLik falls back on numerical derivatives. `reltol`, `iterlim`
## and `printLevel` are forwarded through maxLik's `control` list.
##
## Returns the maxLik fit with the starting values attached as attribute
## "startvals".
bfgs_structwar_ac <- function(setup,
                              init = NULL,
                              scale = FALSE,
                              reltol = sqrt(.Machine$double.eps),
                              iterlim = 1000,
                              printLevel = 1,
                              useGrad = TRUE,
                              finalHessian = FALSE) {
  start_par <- start_vals(setup = setup, init = init)

  if (scale) {
    scaling <- make_parscale(setup)
  } else {
    scaling <- rep(1.0, length(start_par))
  }

  gradient_fn <- NULL
  if (useGrad)
    gradient_fn <- grad_structwar_ac

  opt_control <- list(reltol = reltol,
                      iterlim = iterlim,
                      printLevel = printLevel)

  fit <- maxLik(logLik = loglik_structwar_ac,
                grad = gradient_fn,
                start = start_par,
                method = "BFGS",
                control = opt_control,
                finalHessian = finalHessian,
                parscale = scaling,
                setup = setup)

  attr(fit, "startvals") <- start_par
  fit
}

## Wrapper function for convenient estimation.
##
## Builds the model with structwar_setup() and fits it with
## bfgs_structwar_ac(). Starting values are tried in this order:
##   1. `init` (all zeros when NULL);
##   2. `init_fallback`, if supplied;
##   3. all zeros, if `init` was supplied (otherwise step 1 already used them).
## Any attempt that raises an error of any class (including C++ exceptions
## from the compiled backend) moves on to the next candidate. If every
## attempt fails, the last error is re-raised.
##
## If `data_dispute` and `data_participant` are lists of datasets (e.g.
## multiple imputations), the model is fit to each pair in turn and a list of
## results is returned, with attributes "scale", "reltol" and "iterlim".
## `init`, when given, must then be a list of the same length.
##
## Returns coef(fit) when `coef_only` is TRUE. Otherwise returns a list with
## the fit, the formulas, the data, n_halton, the factor levels used, and the
## log-likelihood contributions at the estimate (loglik_indiv).
est_structwar_ac <- function(f_dispute,
                             f_participant,
                             data_dispute,
                             data_participant,
                             n_halton = 1024,
                             coef_only = FALSE,
                             init = NULL,
                             init_fallback = NULL,
                             scale = FALSE,
                             reltol = sqrt(.Machine$double.eps),
                             iterlim = 1000,
                             printLevel = 1,
                             useGrad = TRUE,
                             finalHessian = FALSE) {
  ## If the data arguments are lists of imputed datasets, automatically run
  ## the estimation on each one and return a list of the results
  if (inherits(data_dispute, "list")) {
    stopifnot(inherits(data_participant, "list"))
    stopifnot(length(data_participant) == length(data_dispute))
    if (!is.null(init)) {
      stopifnot(inherits(init, "list"))
      stopifnot(length(init) == length(data_dispute))
    }
    ans <- foreach (i = seq_along(data_dispute)) %do% {
      est_structwar_ac(f_dispute = f_dispute,
                       f_participant = f_participant,
                       data_dispute = data_dispute[[i]],
                       data_participant = data_participant[[i]],
                       n_halton = n_halton,
                       coef_only = coef_only,
                       init = init[[i]],
                       init_fallback = init_fallback,
                       scale = scale,
                       reltol = reltol,
                       iterlim = iterlim,
                       printLevel = printLevel,
                       useGrad = useGrad,
                       finalHessian = finalHessian)
    }
    attr(ans, "scale") <- scale
    attr(ans, "reltol") <- reltol
    attr(ans, "iterlim") <- iterlim
    return(ans)
  }

  ## Get model object together
  f_dispute <- as.Formula(f_dispute)
  f_participant <- as.Formula(f_participant)
  setup <- structwar_setup(f_dispute = f_dispute,
                           f_participant = f_participant,
                           data_dispute = data_dispute,
                           data_participant = data_participant,
                           n_halton = n_halton)

  ## Double-check there's no missing data
  if (any(vapply(setup, anyNA, logical(1))))
    stop("Missing data present")

  ## One BFGS attempt from the given starting values; any error is returned
  ## as a condition object instead of being raised
  fit_from <- function(start) {
    tryCatch(bfgs_structwar_ac(setup = setup,
                               init = start,
                               scale = scale,
                               reltol = reltol,
                               iterlim = iterlim,
                               printLevel = printLevel,
                               useGrad = useGrad,
                               finalHessian = finalHessian),
             error = identity)
  }

  ## Candidate starting values in the order they are tried; a NULL entry
  ## means all-zero starting values
  candidates <- list(init)
  if (!is.null(init_fallback))
    candidates <- c(candidates, list(init_fallback))
  if (!is.null(init))
    candidates <- c(candidates, list(NULL))

  ## Check against the base "error" class: Rcpp exceptions and other
  ## non-simpleError conditions must also trigger the fallbacks
  for (start in candidates) {
    fit <- fit_from(start)
    if (!inherits(fit, "error"))
      break
  }
  if (inherits(fit, "error"))
    stop(fit)

  if (coef_only) {
    ans <- coef(fit)
  } else {
    ans <- list(fit = fit,
                f_dispute = f_dispute,
                f_participant = f_participant,
                data_dispute = data_dispute,
                data_participant = data_participant,
                n_halton = n_halton,
                xlev_dispute = setup$xlev_dispute,
                xlev_participant = setup$xlev_participant,
                loglik_indiv = loglik_structwar_ac(est = coef(fit),
                                                   setup = setup))
  }

  ans
}
